import torch
import torch.nn as nn
import math
import torch.nn.functional as F
from .utils import *
from .gpn import GPN

class VGPN(GPN):
    """GPN subclass whose propagation step adds a mean-centred correction term.

    Only ``poisson_conv`` is overridden here; every other behaviour comes
    from ``GPN``. The extra ``lamb`` hyper-parameter scales the centred
    term that is added at each propagation step.
    """

    def __init__(self, features, n_hidden, edge_indices_no_diag, idx_train, labels, leaky_rate, adj, dropout, T, lamb=1.):
        """Initialise the base ``GPN`` and store the centring weight.

        All positional arguments are forwarded unchanged to ``GPN.__init__``;
        see that class for their meaning.

        Args:
            lamb: Multiplier applied to the mean-centred state that is added
                at every propagation step in ``poisson_conv``.
        """
        super().__init__(features, n_hidden, edge_indices_no_diag, idx_train, labels, leaky_rate, adj, dropout, T)
        self.lamb = lamb

    def poisson_conv(self):
        """Run ``self.T`` propagation steps and return the final state.

        Starting from an all-zero ``(num_nodes, num_classes)`` state, each
        step computes ``ut <- P @ ut + lamb * (ut - column_mean(ut))``.
        When ``self.isadj`` is falsy, ``self.ft(self.features)`` is added once
        (followed by dropout) right after step number ``self.T - 3``.

        NOTE(review): because the state starts at zero and the update is
        linear in ``ut``, the state stays zero until the feature injection.
        Consequently the result is all zeros when ``self.isadj`` is truthy or
        when ``self.T - 3 < 1`` -- confirm that this is intended.

        Returns:
            The propagated ``(num_nodes, num_classes)`` tensor, allocated on
            the same device as the propagation operator.
        """
        # Propagation operator built from the diagonal-free matrix
        # (presumably the normalised edge weights -- verify in GPN).
        P = self.torch_sparse(self.A_ds_no_diag)

        # Allocate the state next to P instead of hard-coding CUDA, so the
        # model also runs on CPU-only hosts; torch.sparse.mm needs both
        # operands on the same device anyway.
        ut = torch.zeros([self.num_nodes, self.num_classes], device=P.device)

        # Step after which the transformed features are injected (hoisted
        # out of the loop since it does not change between iterations).
        inject_step = self.T - 3

        step = 0
        while step < self.T:
            # Deviation of every node from the per-class mean over all nodes.
            centred = ut - ut.mean(0, keepdim=True)
            ut = torch.sparse.mm(P, ut) + self.lamb * centred
            step += 1
            if not self.isadj and step == inject_step:
                ut = ut + self.ft(self.features)
                ut = F.dropout(ut, self.dropout, training=self.training)
        return ut